#  + - - - - - -
#  |  run with: snakemake final_rule --cores 4 --resources nvidia_gpu=1 --config ...
#  + - - - 

import os
import yaml

# Model names handled by the workflow; each one is substituted for the
# "{model}" wildcard in the rules' file patterns.
MODELS = [
    "Rich",
    "Muon",
    "GlobalPID-im",
    "GlobalPID-nm",
]

# Particle labels, substituted for the "{part}" wildcard.
PARTS = [
    "muon",
    "pion",
    "kaon",
    "proton",
]

# +-------------------+
# |   Initial setup   |
# +-------------------+

# Workflow parameters (data sample, number of epochs, chunk sizes, train
# ratios) are loaded by Snakemake from this file into the global `config`.
configfile: "config/snakemake_config.yml"

# Directory locations are kept in a separate YAML file.
# NOTE(review): yaml.full_load accepts more than plain YAML types; if this
# file only contains plain mappings and strings, yaml.safe_load would be the
# safer loader -- confirm before switching.
with open("config/directories.yml") as file:
    config_dir = yaml.full_load(file)

# Path prefixes used by the rules' input and output patterns.
data_dir, models_dir = config_dir["data_dir"], config_dir["models_dir"]

# Only the known model and particle names may match the "{model}" and
# "{part}" wildcards; each constraint is a regex alternation of those names.
wildcard_constraints:
    model = "|".join(MODELS),
    part = "|".join(PARTS),

# +-----------------------------+
# |   PID GAN models training   |
# +-----------------------------+

# Train the GAN of one (model, part, sample) combination by running the
# matching train_GAN_<model>.py script.
#
# Inputs are matched on the job's own wildcards, so a job only requires (and
# is only re-triggered by) the training set and tX/tY pickles that share its
# model, particle and sample -- not the files of every other combination.
# NOTE(review): assumes each train_GAN_<model>.py reads only the files of its
# own model/part/sample -- confirm against the training scripts.
rule pid_gan_training:
    input:
        data = os.path.join(data_dir, "{model}-{part}-{sample}-trainset.npz"),
        tX = os.path.join(models_dir, "{model}_{part}_models/tX_{sample}.pkl"),
        tY = os.path.join(models_dir, "{model}_{part}_models/tY_{sample}.pkl"),

    params:
        # Training hyper-parameters taken from the Snakemake config file.
        num_epochs = config["num_epochs"],
        chunk_size = config["gan_chunk_size"],
        train_ratio = config["gan_train_ratio"],

    output:
        # The rule's output is a directory, hence the directory() flag.
        directory(os.path.join(models_dir, "{model}_{part}_models/latest_{model}_{part}_{sample}_gan")),

    resources:
        # Each job claims one unit of the "nvidia_gpu" resource, so the number
        # of concurrent trainings is bounded by `--resources nvidia_gpu=N`.
        nvidia_gpu = 1,

    log:
        "logs/gan-training_{model}_{part}_{sample}.log"

    shell:
        # stderr is redirected together with stdout so that warnings and
        # tracebacks of a failed training end up in the log file as well.
        "python train_GAN_{wildcards.model}.py "
        "-p {wildcards.part} "
        "-E {params.num_epochs} "
        "-C {params.chunk_size} "
        "-T {params.train_ratio} "
        "-D {wildcards.sample} "
        "--latest > {log} 2>&1"

# +--------------------------------+
# |   isMuon ANN models training   |
# +--------------------------------+

# Train the isMuon ANN of one (part, sample) combination by running
# train_ANN_isMuon.py.
#
# Inputs are matched on the job's own wildcards, so a job only requires (and
# is only re-triggered by) the isMuon training set and tX pickle of its own
# particle and sample. The patterns contain no "{model}" placeholder, so no
# model list is involved here.
# NOTE(review): assumes train_ANN_isMuon.py reads only the files of its own
# part/sample -- confirm against the training script.
rule im_ann_training:
    input:
        data = os.path.join(data_dir, "isMuon-{part}-{sample}-trainset.npz"),
        tX = os.path.join(models_dir, "isMuon_{part}_models/tX_{sample}.pkl"),

    params:
        # Training hyper-parameters taken from the Snakemake config file.
        num_epochs = config["num_epochs"],
        chunk_size = config["ismuon_chunk_size"],
        train_ratio = config["ismuon_train_ratio"],

    output:
        # The rule's output is a directory, hence the directory() flag.
        directory(os.path.join(models_dir, "isMuon_{part}_models/latest_isMuon_{part}_{sample}_ann")),

    resources:
        # Each job claims one unit of the "nvidia_gpu" resource, so the number
        # of concurrent trainings is bounded by `--resources nvidia_gpu=N`.
        nvidia_gpu = 1,

    log:
        "logs/ann-training_isMuon_{part}_{sample}.log"

    shell:
        # stderr is redirected together with stdout so that warnings and
        # tracebacks of a failed training end up in the log file as well.
        "python train_ANN_isMuon.py "
        "-p {wildcards.part} "
        "-E {params.num_epochs} "
        "-C {params.chunk_size} "
        "-T {params.train_ratio} "
        "-D {wildcards.sample} "
        "--latest > {log} 2>&1"

# +------------------+
# |    Final rule    |
# +------------------+

# Aggregation target: requests the trained GAN directory of every
# (model, part) pair and the trained isMuon ANN directory of every part for
# the configured data sample, then creates a marker file.
rule final_rule:
    input:
        # One GAN output directory per model/particle combination.
        pid_gan_models = expand(
            os.path.join(models_dir, "{model}_{part}_models/latest_{model}_{part}_{sample}_gan"),
            model=MODELS, part=PARTS, sample=config["data_sample"],
        ),
        # One isMuon ANN output directory per particle.
        im_ann_model = expand(
            os.path.join(models_dir, "isMuon_{part}_models/latest_isMuon_{part}_{sample}_ann"),
            part=PARTS, sample=config["data_sample"],
        ),

    output:
        # Hidden marker file created once all inputs above exist.
        ".all_rules_completed"

    shell:
        "touch {output}"
